use crate::common::{
    symcodec::SymbolCodec,
    symbol::SymbolWithin8Bits,
    bitio::{BitsRead, BitsWrite}
};
use core::slice;
use std::{io::{Error as IOError, Read, Write, Seek, SeekFrom}, marker::PhantomData};

/// Run-length codec with an adaptive count-field width of up to 32 bits.
///
/// This variant leaves symbols that occur only once in a row untouched and
/// run-length encodes only runs of two or more identical symbols.
pub struct RLECodec<S> {
    __unused__: PhantomData<S>,
}

impl<S: SymbolWithin8Bits> RLECodec<S> {
    /// Creates a new RLE codec for symbols of type `S`.
    ///
    /// The codec holds no state besides its symbol type marker, so it can be
    /// reused for any number of `encode`/`decode` calls.
    pub fn new() -> Self {
        RLECodec {
            __unused__: PhantomData,
        }
    }
}

impl<S: SymbolWithin8Bits> Default for RLECodec<S> {
    /// Equivalent to [`RLECodec::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Size in bytes of the fixed RLE stream header:
/// symbol bits (1) + count field bits (1) + symbol count (4)
/// + counts region offset (4) + counts region length (4).
const RLE_HEADER_LEN: u64 = 1 + 1 + 4 + 4 + 4;

impl<S: SymbolWithin8Bits> SymbolCodec for RLECodec<S> {
    /// Encodes up to `symbol_read_limit` symbols from `i_byte_reader` into
    /// `o_byte_writer` and returns the number of symbols consumed.
    ///
    /// A run of length one is copied verbatim. A run of two or more identical
    /// symbols is written as two copies of the symbol, and its full length is
    /// recorded in a separate counts region placed after the symbol string.
    /// Counts are accumulated in memory and written once the widest count is
    /// known, so every count record uses the same, minimal bit width.
    ///
    /// Encoded layout (multi-byte fields are little-endian):
    ///
    /// ```text
    /// |symbol len in bits|                        1 byte
    /// |counts field len in bits|                  1 byte
    /// |number of symbols encoded|                 4 bytes
    /// |the offset of the start of counts region|  4 bytes
    /// |the length in bytes of the counts region|  4 bytes
    /// |symbol string|
    /// |count records|
    /// ```
    ///
    /// The header is written at absolute offsets (`SeekFrom::Start(1)`), so the
    /// writer is assumed to be positioned at offset 0 on entry. On success the
    /// writer is left positioned just past the counts region.
    ///
    /// An empty input (or `symbol_read_limit == 0`) still produces a complete
    /// header describing zero symbols.
    ///
    /// # Errors
    /// Propagates any I/O error from the reader or the writer.
    fn encode<R, W>(
        &self,
        i_byte_reader: &mut R,
        o_byte_writer: &mut W,
        symbol_read_limit: u32,
    ) -> Result<u32, IOError>
    where
        R: Read,
        W: Write + Seek,
    {
        let symbol_bits = S::bits_needed();

        // Write the symbol width now and reserve room for the rest of the
        // header, whose fields are only known after all symbols are consumed.
        o_byte_writer.write_all(slice::from_ref(&symbol_bits))?;
        o_byte_writer.seek(SeekFrom::Current(RLE_HEADER_LEN as i64 - 1))?;

        let mut bitsreader = BitsRead::create_with(i_byte_reader);
        let mut bitswriter = BitsWrite::create_with(o_byte_writer);

        // Lengths of every run of two or more symbols, in stream order.
        let mut run_lengths: Vec<u32> = Vec::new();
        let mut passed = 0u32;
        let mut prev = 0u8;

        if symbol_read_limit > 0 && bitsreader.read_bits_8_8_unchecked(&mut prev, symbol_bits)? {
            // Emits one finished run: the symbol once, or twice plus a recorded
            // count when the run is long enough to be run-length encoded.
            let mut emit_run = |symbol: u8, len: u32| -> Result<(), IOError> {
                bitswriter.write_bits_32_8_unchecked(u32::from(symbol), symbol_bits)?;
                if len > 1 {
                    // The doubled symbol tells the decoder to fetch this run's
                    // length from the counts region.
                    bitswriter.write_bits_32_8_unchecked(u32::from(symbol), symbol_bits)?;
                    run_lengths.push(len);
                }
                Ok(())
            };

            passed = 1;
            let mut run_len = 1u32;
            let mut b = 0u8;

            while passed < symbol_read_limit
                && bitsreader.read_bits_8_8_unchecked(&mut b, symbol_bits)?
            {
                if b == prev {
                    run_len += 1;
                } else {
                    emit_run(prev, run_len)?;
                    prev = b;
                    run_len = 1;
                }
                passed += 1;
            }

            // The last run is still pending when the input ends.
            emit_run(prev, run_len)?;
        }
        bitswriter.flush_buffered_bits()?;

        // Narrowest width able to hold every recorded run length, including
        // the final run. Use at least 1 bit so the header field is never 0.
        let longest_run = run_lengths.iter().copied().max().unwrap_or(0);
        let count_field_bits = (u32::BITS - longest_run.leading_zeros()).max(1) as u8;

        let counts_records_starts = bitswriter.seek_byte(SeekFrom::Current(0))? as u32;
        for &len in &run_lengths {
            bitswriter.write_bits_32_8_unchecked(len, count_field_bits)?;
        }
        bitswriter.flush_buffered_bits()?;

        let counts_records_ends = o_byte_writer.stream_position()?;
        let counts_records_bytes = counts_records_ends as u32 - counts_records_starts;

        // Back-fill the header fields reserved above.
        o_byte_writer.seek(SeekFrom::Start(1))?;
        o_byte_writer.write_all(slice::from_ref(&count_field_bits))?;
        o_byte_writer.write_all(&passed.to_le_bytes())?;
        o_byte_writer.write_all(&counts_records_starts.to_le_bytes())?;
        o_byte_writer.write_all(&counts_records_bytes.to_le_bytes())?;

        // Leave the writer after the encoded data rather than inside the header.
        o_byte_writer.seek(SeekFrom::Start(counts_records_ends))?;

        Ok(passed)
    }

    /// Decodes a stream produced by `encode` from `i_byte_reader` into
    /// `o_byte_writer` and returns the number of symbols written.
    ///
    /// The header and counts region are located by absolute offsets, so the
    /// reader is assumed to be positioned at offset 0 of the encoded stream.
    /// The entire counts region is loaded into memory before decoding starts.
    ///
    /// # Errors
    /// - Propagates I/O errors from the reader or writer, including
    ///   `UnexpectedEof` when the header itself is truncated.
    /// - `UnexpectedEof` if the counts region is shorter than the header
    ///   claims, or a doubled symbol has no matching count record.
    /// - `InvalidData` if a count record is below 2 or would expand past the
    ///   number of symbols declared in the header.
    fn decode<R, W>(
        i_byte_reader: &mut R,
        o_byte_writer: &mut W,
    ) -> Result<u32, IOError>
    where
        R: Read + Seek,
        W: Write,
    {
        let mut onebyte_buf = [0u8];
        let mut fourbyte_buf = [0u8; 4];

        i_byte_reader.read_exact(&mut onebyte_buf)?;
        let symbol_bits = onebyte_buf[0];
        i_byte_reader.read_exact(&mut onebyte_buf)?;
        let count_field_bits = onebyte_buf[0];
        i_byte_reader.read_exact(&mut fourbyte_buf)?;
        let n_symbols_encoded = u32::from_le_bytes(fourbyte_buf);
        i_byte_reader.read_exact(&mut fourbyte_buf)?;
        let counts_records_start = u32::from_le_bytes(fourbyte_buf);
        i_byte_reader.read_exact(&mut fourbyte_buf)?;
        let counts_records_bytes = u32::from_le_bytes(fourbyte_buf);

        if n_symbols_encoded == 0 {
            return Ok(0);
        }

        // Load the counts region. Reading through `take` grows the buffer only
        // as far as real data exists, instead of trusting the header's length
        // for one up-front allocation.
        let mut counts = Vec::new();
        i_byte_reader.seek(SeekFrom::Start(u64::from(counts_records_start)))?;
        i_byte_reader
            .by_ref()
            .take(u64::from(counts_records_bytes))
            .read_to_end(&mut counts)?;
        if counts.len() != counts_records_bytes as usize {
            return Err(IOError::new(
                std::io::ErrorKind::UnexpectedEof,
                "RLE stream truncated: counts region is shorter than its header claims",
            ));
        }

        i_byte_reader.seek(SeekFrom::Start(RLE_HEADER_LEN))?;

        let mut counts_byte_reader = counts.as_slice();

        let mut bitsreader_counts = BitsRead::create_with(&mut counts_byte_reader);
        let mut bitsreader_symbols = BitsRead::create_with(i_byte_reader);
        let mut bitswriter = BitsWrite::create_with(o_byte_writer);

        let mut prev = 0u8;
        let mut b = 0u8;

        if !bitsreader_symbols.read_bits_8_8_unchecked(&mut prev, symbol_bits)? {
            return Ok(0);
        }
        bitswriter.write_bits_32_8_unchecked(u32::from(prev), symbol_bits)?;

        let mut passed = 1u32;
        let mut count = 0u32;

        while passed < n_symbols_encoded
            && bitsreader_symbols.read_bits_8_8_unchecked(&mut b, symbol_bits)?
        {
            bitswriter.write_bits_32_8_unchecked(u32::from(b), symbol_bits)?;
            passed += 1;

            // Two identical symbols in a row mark a run whose full length is
            // stored in the counts region; both copies are already written.
            if b == prev {
                if !bitsreader_counts.read_bits_32_128_unchecked(&mut count, count_field_bits)? {
                    return Err(IOError::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "RLE stream corrupt: missing count record for a repeated symbol",
                    ));
                }
                let remaining = count
                    .checked_sub(2)
                    .filter(|&rest| rest <= n_symbols_encoded - passed)
                    .ok_or_else(|| {
                        IOError::new(
                            std::io::ErrorKind::InvalidData,
                            "RLE stream corrupt: count record out of range",
                        )
                    })?;

                for _ in 0..remaining {
                    bitswriter.write_bits_32_8_unchecked(u32::from(b), symbol_bits)?;
                }
                passed += remaining;
            }

            prev = b;
        }

        bitswriter.flush_buffered_bits()?;

        Ok(passed)
    }
}
